{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cross Entropy Loss 计算原理"
   ]
  },
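  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**torch.nn.CrossEntropyLoss 计算公式如下**\n",
    "$$ \\mathrm{loss}(x, class) = -\\log\\left(\\frac{\\exp(x[class])}{\\sum_{j}\\exp(x[j])}\\right) = -x[class] + \\log\\left(\\sum_{j}\\exp(x[j])\\right)$$\n",
    "\n",
    "其中 $x$ 为单个样本在各类别上的原始得分（logits），$class$ 为该样本真实类别的下标；默认 `reduction='mean'`，最终的 loss 是所有样本 loss 的平均值。"
   ]
  },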
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Criterion built with default arguments; later cells compare its output\n",
    "# against the mean of hand-computed per-sample losses.\n",
    "loss = nn.CrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 二维数据上的计算与理解"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fix the RNG seed so Restart & Run All reproduces the same numbers.\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# input: 3 records, 5 classes (one raw score per class)\n",
    "input = torch.randn(3, 5, requires_grad=True)\n",
    "\n",
    "# target: shape (3,), one class label in [0, 5) for each of the 3 records\n",
    "target = torch.randint(0, 5, (3,), dtype=torch.long)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NumPy copies of the logits and labels for the hand-written computation.\n",
    "# detach() is applied first because `input` was created with requires_grad=True.\n",
    "input_array, target_array = (t.detach().numpy() for t in (input, target))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "pytorch tools calculation: tensor(2.1582, grad_fn=<NllLossBackward>)\n",
      "My calculation: 2.158160789301556\n"
     ]
    }
   ],
   "source": [
    "# Loop bounds come from the data itself instead of hard-coded 3 / 5.\n",
    "n_samples, n_classes = input_array.shape\n",
    "\n",
    "loss_list = []\n",
    "for i in range(n_samples):\n",
    "    # -x[class]: score of sample i at its true class index target_array[i]\n",
    "    first = -input_array[i, target_array[i]]\n",
    "\n",
    "    # log(sum_j exp(x[j])): accumulate exp of sample i's score for every class j\n",
    "    second = 0\n",
    "    for j in range(n_classes):\n",
    "        second = second + np.exp(input_array[i, j])\n",
    "    second = np.log(second)\n",
    "\n",
    "    loss_list.append(first + second)\n",
    "\n",
    "# Average the per-sample losses before comparing with the criterion's result.\n",
    "my_loss = np.mean(loss_list)\n",
    "pytorch_loss = loss(input, target)\n",
    "\n",
    "# Fail loudly if the hand-computed value ever disagrees with PyTorch's.\n",
    "np.testing.assert_allclose(my_loss, pytorch_loss.item(), rtol=1e-5, atol=1e-6)\n",
    "\n",
    "print(\"pytorch tools calculation:\", pytorch_loss)\n",
    "print(\"My calculation:\", my_loss)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 高维度数据的计算过程与原理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fix the RNG seed so re-running this cell always yields the same data.\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# input: batch size 6, 3 classes, 5 x 5 image\n",
    "input = torch.randn(6, 3, 5, 5, requires_grad=True)\n",
    "\n",
    "# target: batch size 6, one class label in [0, 3) for each pixel of the 5 x 5 image\n",
    "target = torch.randint(0, 3, (6, 1, 5, 5), dtype=torch.long)\n",
    "\n",
    "# Shapes: input is (6, 3, 5, 5), target is (6, 1, 5, 5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NumPy copies of the per-pixel logits and labels for the hand-written computation.\n",
    "# detach() is applied first because `input` was created with requires_grad=True.\n",
    "input_array, target_array = (t.detach().numpy() for t in (input, target))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "pytorch tools calculation: tensor(1.5038, grad_fn=<NllLoss2DBackward>)\n",
      "My calculation: 1.5038311872782133\n"
     ]
    }
   ],
   "source": [
    "# Shape-derived bounds instead of hard-coded 6 / 3 / 5 / 5:\n",
    "# input_array is laid out as (batch, classes, rows, cols).\n",
    "n_batch, n_classes, n_rows, n_cols = input_array.shape\n",
    "\n",
    "# Every pixel of every image is one sample x: n_batch * n_rows * n_cols losses in total.\n",
    "loss_list = []\n",
    "for batch_i in range(n_batch):\n",
    "    for r_i in range(n_rows):\n",
    "        for c_i in range(n_cols):\n",
    "            # -x[class]: score at this pixel's true class (target keeps a singleton axis 1)\n",
    "            true_class = target_array[batch_i, 0, r_i, c_i]\n",
    "            first = -input_array[batch_i, true_class, r_i, c_i]\n",
    "\n",
    "            # log(sum_j exp(x[j])), with j running over the class axis\n",
    "            second = 0\n",
    "            for class_i in range(n_classes):\n",
    "                second = second + np.exp(input_array[batch_i, class_i, r_i, c_i])\n",
    "            second = np.log(second)\n",
    "\n",
    "            loss_list.append(first + second)\n",
    "my_loss = np.mean(loss_list)\n",
    "\n",
    "# The criterion takes class indices without the singleton axis, hence squeeze(1).\n",
    "pytorch_loss = loss(input, target.squeeze(1))\n",
    "\n",
    "# Fail loudly if the hand-computed value ever disagrees with PyTorch's.\n",
    "np.testing.assert_allclose(my_loss, pytorch_loss.item(), rtol=1e-5, atol=1e-6)\n",
    "\n",
    "print(\"pytorch tools calculation:\", pytorch_loss)\n",
    "print(\"My calculation:\", my_loss)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py3-pytorch",
   "language": "python",
   "name": "py3-pytorch"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
